from typing import Dict, Callable, Sequence

from einops import rearrange
import torch
from torch import nn
from torch.utils.data import Dataset
from collections import defaultdict

from dataset.cyclone import CycloneSample
from eval.complex_metrics import ComplexMetrics


def get_rollout_fn(
    n_steps: int,
    bundle_steps: int,
    dataset: Dataset,
    use_amp: bool = False,
    use_bf16: bool = False,
    device: str = "cuda",
) -> Callable:
    """Build a closure that rolls a model forward autoregressively.

    Args:
        n_steps: Upper bound on the number of model calls per rollout.
        bundle_steps: Stride between consecutive timestep indices; also used
            to convert "timesteps left in the file" into "model calls left".
        dataset: Object providing ``num_ts(file_index)`` and
            ``get_timesteps(file_indices, timestep_indices)``.
        use_amp: Enable ``torch.autocast`` around the model call.
        use_bf16: Autocast to ``bfloat16`` instead of ``float16``.
        device: Device the per-step timestep conditioning is moved to. May
            carry an index (e.g. ``"cuda:0"``).

    Returns:
        ``_rollout(model, inputs, idx_data, conds)``, which returns a dict of
        float32 CPU tensors with the rollout axis first.
    """
    # torch.autocast wants the bare device type ("cuda", "cpu"); strip any
    # ":<index>" suffix so indexed devices work too.
    autocast_device = torch.device(device).type
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

    def _rollout(
        model: nn.Module,
        inputs: Dict,
        idx_data: Dict,
        conds: Dict,
    ) -> Dict[str, torch.Tensor]:
        """Run the rollout for one batch.

        Args:
            model: Called as ``model(**inputs, **conds)``; must return a
                mapping that contains every key of ``inputs`` (those entries
                are fed back as the next step's inputs) and optionally
                ``"flux"``.
            inputs: Initial model inputs. Not modified.
            idx_data: Must hold tensors ``"file_index"`` and
                ``"timestep_index"`` with one entry per sample.
            conds: Extra keyword arguments for the model. Not modified; a
                ``"timestep"`` entry is set on an internal copy each step.

        Returns:
            Dict mapping each prediction key (plus ``"flux"`` if the model
            emitted it) to a float32 CPU tensor whose leading axis indexes
            rollout time. Empty if no rollout step fits in the data.
        """
        file_idxs = idx_data["file_index"].tolist()
        start_ts = idx_data["timestep_index"].tolist()

        # cap the steps depending on the timesteps left in each sample's file;
        # the whole batch is rolled out for the shortest admissible horizon
        steps_left = [
            min((dataset.num_ts(int(f_idx)) - int(ts)) // bundle_steps - 1, n_steps)
            for f_idx, ts in zip(file_idxs, start_ts)
        ]
        rollout_steps = min(steps_left, default=0)
        if rollout_steps <= 0:
            # nothing to roll out; avoid querying the dataset with empty ranges
            return {}

        tot_ts = rollout_steps * bundle_steps
        # timestep indices visited by each sample, one per model call
        ts_idxs = [
            list(range(int(ts), int(ts) + tot_ts, bundle_steps)) for ts in start_ts
        ]
        tsteps = dataset.get_timesteps(idx_data["file_index"], torch.tensor(ts_idxs))

        inputs_t = inputs.copy()
        # shallow copy so the caller's conditioning dict is left untouched
        conds_t = dict(conds)
        preds = defaultdict(list)
        fluxes = []
        with torch.no_grad():
            # move bundles forward, rollout in blocks
            for i in range(rollout_steps):
                with torch.autocast(autocast_device, dtype=amp_dtype, enabled=use_amp):
                    conds_t["timestep"] = tsteps[:, i].to(device)
                    pred = model(**inputs_t, **conds_t)
                    if "flux" in pred:
                        # flux is collected separately and never fed back
                        fluxes.append(pred.pop("flux").cpu())

                    # feed predictions back as next-step inputs in full precision
                    for key in inputs_t:
                        inputs_t[key] = pred[key].clone().float()

                for key, value in pred.items():
                    value = value.cpu()
                    # add time dim if not there
                    # NOTE(review): ranks 4/5/7 are presumably the outputs
                    # that lack a time axis -- confirm against the models
                    preds[key].append(
                        value.unsqueeze(2) if value.ndim in (4, 5, 7) else value
                    )

        out = {}
        for key, chunks in preds.items():
            # concatenate along time (axis 2), then bring time to the front
            stacked = torch.cat(chunks, 2).movedim(2, 0)
            # only return desired size
            out[key] = stacked[:tot_ts]
        if fluxes:
            # per-step fluxes are stacked on a trailing axis, then time goes first
            out["flux"] = torch.stack(fluxes, dim=-1).movedim(-1, 0)
        # to float32 for integrals etc
        # TODO is this the best approach?
        return {k: p.to(dtype=torch.float32) for k, p in out.items()}

    return _rollout


def validation_metrics(
    stage: str,
    tgts: Dict[str, torch.Tensor],
    preds: Dict[str, torch.Tensor],
    geometry: Dict[str, torch.Tensor],
    loss_wrap: nn.Module,
    eval_integrals: bool = True,
    integral_loss_type: str = "mse",  # loss for phi_integral
):
    """Compute validation metrics for one batch of predictions.

    Args:
        stage: Accepted for interface compatibility; not read by this function.
        tgts: Target tensors keyed by field name.
        preds: Predicted tensors keyed by field name. May contain keys that
            are absent from ``tgts``; those are not shape-checked here.
        geometry: Forwarded unchanged to ``loss_wrap``.
        loss_wrap: Callable invoked with keyword arguments ``preds``, ``tgts``,
            ``geometry``, ``compute_integrals`` and ``integral_loss_type``;
            must return a 4-tuple whose second and third items are the
            metrics mapping and the integrated quantities.
        eval_integrals: Passed to ``loss_wrap`` as ``compute_integrals``.
        integral_loss_type: Passed through to ``loss_wrap``.

    Returns:
        ``(metrics, integrated)`` as produced by ``loss_wrap``. When both
        ``preds`` and ``tgts`` contain ``"df"``, ``metrics`` additionally gets
        ``complex_<name>`` float32 tensor entries from ``ComplexMetrics``.

    Raises:
        AssertionError: If a key present in both dicts has mismatching shapes.
    """
    # VAE/VQ-VAE models emit extras (mu, logvar, vq_commit_loss, ...) that have
    # no counterpart in tgts, so only keys present on both sides are checked.
    shared = set(preds.keys()) & set(tgts.keys())
    for name in shared:
        assert tgts[name].shape == preds[name].shape, f"Mismatch in shapes for {name}"

    wrapped = loss_wrap(
        preds=preds,
        tgts=tgts,
        geometry=geometry,
        compute_integrals=eval_integrals,
        integral_loss_type=integral_loss_type,
    )
    # first and last items of the wrapper's result are not needed here
    _, metrics, integrated, _ = wrapped

    # complex metrics only apply when df exists on both sides
    if "df" not in preds or "df" not in tgts:
        return metrics, integrated

    try:
        evaluator = ComplexMetrics()
        results = evaluator.evaluate_all(preds["df"], tgts["df"])
        # expose every complex metric under a "complex_" prefix
        for name, value in results.items():
            metrics[f"complex_{name}"] = torch.tensor(value, dtype=torch.float32)
    except Exception as err:
        # best effort: complex metrics must never break validation
        print(f"Warning: Failed to compute complex metrics: {err}")

    # TODO(diff) what 5d metrics can we include?
    return metrics, integrated
